library(gmm)
source("joint-iv-utils.R")
# Simulation design ----
# Probability of the first of the nine types; the remaining eight types
# share the rest of the probability mass equally.
cc.prob <- 0.4
cprobs <- c(cc.prob, rep((1 - cc.prob) / 8, times = 8))
N <- 1000     # observations per simulated data set
sims <- 1000  # number of Monte Carlo replications

# Per-replication storage ----
# Every holder starts as a length-`sims` vector of NA and is filled by index.
hold11 <- hold10 <- hold01 <- hold00 <- rep(NA, times = sims)
ghold <- rep(NA, times = sims)      # first GMM coefficient
laje.hold <- rep(NA, times = sims)  # joint.iv() point estimate
laje.se <- rep(NA, times = sims)    # joint.iv() standard error
gmm.hold <- rep(NA, times = sims)   # GMM contrast of coefficients 9 and 21
gmm.se <- rep(NA, times = sims)     # standard error of that contrast
num.hold <- rep(NA, times = sims)   # difference-in-differences numerator


# Moment function handed to gmm(): one row per observation, one column per
# moment condition.
#
# Arguments
#   theta  Parameter vector of length 24, read by position:
#            theta[1:8]   type shares rho.cc, rho.ca, rho.cn, rho.ac, rho.aa,
#                         rho.an, rho.nc, rho.na; rho.nn is one minus their sum.
#            theta[9:24]  psi.<d1 d2>.<type> terms, each of which enters the
#                         moments only multiplied by the matching share.
#   x      Numeric matrix whose columns are, in order, d1, d2, z1, z2, y.
#
# Value
#   Matrix with nrow(x) rows and 28 named columns: 12 share moments (g*)
#   followed by 16 outcome moments (s*). A name such as g10.01 reads
#   "indicator d1 = 1, d2 = 0 within the cell z1 = 0, z2 = 1".
gfun <- function(theta, x) {
  treat1 <- x[, 1]
  treat2 <- x[, 2]
  inst1 <- x[, 3]
  inst2 <- x[, 4]
  y <- x[, 5]

  # Indicators for the four instrument cells (z1, z2).
  z11 <- inst1 * inst2
  z10 <- inst1 * (1 - inst2)
  z01 <- (1 - inst1) * inst2
  z00 <- (1 - inst1) * (1 - inst2)

  # Indicators for the four treatment combinations (d1, d2).
  d11 <- treat1 * treat2
  d10 <- treat1 * (1 - treat2)
  d01 <- (1 - treat1) * treat2
  d00 <- (1 - treat1) * (1 - treat2)

  # Type shares; the ninth share is pinned down by the adding-up constraint.
  rho.cc <- theta[1]
  rho.ca <- theta[2]
  rho.cn <- theta[3]
  rho.ac <- theta[4]
  rho.aa <- theta[5]
  rho.an <- theta[6]
  rho.nc <- theta[7]
  rho.na <- theta[8]
  rho.nn <- (1 - sum(theta[1:8]))

  # Outcome parameters, grouped by treatment combination.
  psi.11.cc <- theta[9]
  psi.11.ca <- theta[10]
  psi.11.ac <- theta[11]
  psi.11.aa <- theta[12]

  psi.10.cc <- theta[13]
  psi.10.an <- theta[14]
  psi.10.ac <- theta[15]
  psi.10.cn <- theta[16]

  psi.01.cc <- theta[17]
  psi.01.nc <- theta[18]
  psi.01.na <- theta[19]
  psi.01.ca <- theta[20]

  psi.00.cc <- theta[21]
  psi.00.cn <- theta[22]
  psi.00.nc <- theta[23]
  psi.00.nn <- theta[24]

  # Share moments. Within each instrument cell only the d11, d10 and d01
  # moments are used: because d11 + d10 + d01 + d00 = 1 and the nine shares
  # sum to one, the d00 moment equals minus the sum of the other three.
  share.z11 <- cbind(
    g11.11 = z11 * (d11 - rho.cc - rho.aa - rho.ac - rho.ca),
    g10.11 = z11 * (d10 - rho.cn - rho.an),
    g01.11 = z11 * (d01 - rho.nc - rho.na)
  )
  share.z01 <- cbind(
    g11.01 = z01 * (d11 - rho.aa - rho.ac),
    g10.01 = z01 * (d10 - rho.an),
    g01.01 = z01 * (d01 - rho.cc - rho.nc - rho.na - rho.ca)
  )
  share.z10 <- cbind(
    g11.10 = z10 * (d11 - rho.aa - rho.ca),
    g10.10 = z10 * (d10 - rho.cc - rho.ac - rho.cn - rho.an),
    g01.10 = z10 * (d01 - rho.na)
  )
  share.z00 <- cbind(
    g11.00 = z00 * (d11 - rho.aa),
    g10.00 = z00 * (d10 - rho.ac - rho.an),
    g01.00 = z00 * (d01 - rho.ca - rho.na)
  )

  # Outcome moments: all four treatment combinations within every cell.
  outcome.z11 <- cbind(
    s11.11 = z11 * (y * d11 - psi.11.cc * rho.cc - psi.11.ca * rho.ca - psi.11.ac * rho.ac - psi.11.aa * rho.aa),
    s10.11 = z11 * (y * d10 - psi.10.cn * rho.cn - psi.10.an * rho.an),
    s01.11 = z11 * (y * d01 - psi.01.nc * rho.nc - psi.01.na * rho.na),
    s00.11 = z11 * (y * d00 - psi.00.nn * rho.nn)
  )
  outcome.z10 <- cbind(
    s11.10 = z10 * (y * d11 - psi.11.aa * rho.aa - psi.11.ca * rho.ca),
    s10.10 = z10 * (y * d10 - psi.10.cc * rho.cc - psi.10.ac * rho.ac - psi.10.cn * rho.cn - psi.10.an * rho.an),
    s01.10 = z10 * (y * d01 - psi.01.na * rho.na),
    s00.10 = z10 * (y * d00 - psi.00.nn * rho.nn - psi.00.nc * rho.nc)
  )
  outcome.z01 <- cbind(
    s11.01 = z01 * (y * d11 - psi.11.aa * rho.aa - psi.11.ac * rho.ac),
    s10.01 = z01 * (y * d10 - psi.10.an * rho.an),
    s01.01 = z01 * (y * d01 - psi.01.cc * rho.cc - psi.01.nc * rho.nc - psi.01.na * rho.na - psi.01.ca * rho.ca),
    s00.01 = z01 * (y * d00 - psi.00.cn * rho.cn - psi.00.nn * rho.nn)
  )
  outcome.z00 <- cbind(
    s11.00 = z00 * (y * d11 - psi.11.aa * rho.aa),
    s10.00 = z00 * (y * d10 - psi.10.an * rho.an - psi.10.ac * rho.ac),
    s01.00 = z00 * (y * d01 - psi.01.na * rho.na - psi.01.ca * rho.ca),
    s00.00 = z00 * (y * d00 - psi.00.cc * rho.cc - psi.00.nn * rho.nn - psi.00.cn * rho.cn - psi.00.nc * rho.nc)
  )

  # Column order is part of the contract with gmm(): share moments for cells
  # 11, 01, 10, 00, then outcome moments for cells 11, 10, 01, 00.
  cbind(share.z11, share.z01, share.z10, share.z00,
        outcome.z11, outcome.z10, outcome.z01, outcome.z00)
}

# Closed-form starting values for the 24-element parameter vector used by
# gfun(), obtained from differences of sample means across the four
# instrument cells.
#
# Arguments
#   y       Numeric outcome vector.
#   d1, d2  0/1 treatment indicators.
#   z1, z2  0/1 instrument indicators.
#
# Value
#   Named numeric vector of length 24, returned visibly, in the positional
#   order gfun() expects: the eight shares rho.cc, rho.ca, rho.cn, rho.ac,
#   rho.aa, rho.an, rho.nc, rho.na, followed by psi.11.{cc,ca,ac,aa},
#   psi.10.{cc,an,ac,cn}, psi.01.{cc,nc,na,ca} and psi.00.{cc,cn,nc,nn}.
#   Every element carries its own unique name.
#
# Errors
#   Stops if any of the four instrument cells has no observations, because
#   every estimate below is built from within-cell means.
local.est <- function(y, d1, d2, z1, z2) {
  z11 <- z1 * z2
  z10 <- z1 * (1 - z2)
  z01 <- (1 - z1) * z2
  z00 <- (1 - z1) * (1 - z2)
  d11 <- d1 * d2
  d10 <- d1 * (1 - d2)
  d01 <- (1 - d1) * d2
  d00 <- (1 - d1) * (1 - d2)

  cell.n <- c(
    z11 = sum(z11 == 1, na.rm = TRUE),
    z10 = sum(z10 == 1, na.rm = TRUE),
    z01 = sum(z01 == 1, na.rm = TRUE),
    z00 = sum(z00 == 1, na.rm = TRUE)
  )
  if (any(cell.n == 0)) {
    stop("local.est: no observations in instrument cell(s): ",
         paste(names(cell.n)[cell.n == 0], collapse = ", "),
         call. = FALSE)
  }

  # Mean of `v` among observations in the given instrument cell.
  cell.mean <- function(v, cell) mean(v[cell == 1])

  rho <- c(
    cc = cell.mean(d11, z11) - cell.mean(d11, z10) - cell.mean(d11, z01) + cell.mean(d11, z00),
    ca = cell.mean(d11, z10) - cell.mean(d11, z00),
    cn = cell.mean(d00, z01) - cell.mean(d00, z11),
    ac = cell.mean(d11, z01) - cell.mean(d11, z00),
    aa = cell.mean(d11, z00),
    an = cell.mean(d10, z01),
    nc = cell.mean(d00, z10) - cell.mean(d00, z11),
    na = cell.mean(d01, z10),
    nn = cell.mean(d00, z11)
  )

  # Sampling noise can push a differenced share to zero or below. Floor those
  # at 0.01 and renormalise so the shares sum to one and the divisions below
  # are well defined.
  rho[rho <= 0] <- 0.01
  rho <- rho / sum(rho)

  yd11 <- y * d11
  yd10 <- y * d10
  yd01 <- y * d01
  yd00 <- y * d00

  # rho[["..."]] drops the element name so it does not leak into the psi's.
  psi.11.aa <- cell.mean(yd11, z00) / rho[["aa"]]
  psi.10.an <- cell.mean(yd10, z01) / rho[["an"]]
  psi.01.na <- cell.mean(yd01, z10) / rho[["na"]]
  psi.00.nn <- cell.mean(yd00, z11) / rho[["nn"]]
  psi.11.ac <- (cell.mean(yd11, z01) - cell.mean(yd11, z00)) / rho[["ac"]]
  psi.11.ca <- (cell.mean(yd11, z10) - cell.mean(yd11, z00)) / rho[["ca"]]
  psi.11.cc <- (cell.mean(yd11, z11) - cell.mean(yd11, z10) - cell.mean(yd11, z01) + cell.mean(yd11, z00)) / rho[["cc"]]
  psi.10.cc <- (cell.mean(yd10, z10) - cell.mean(yd10, z11) - cell.mean(yd10, z00) + cell.mean(yd10, z01)) / rho[["cc"]]
  psi.01.cc <- (cell.mean(yd01, z01) - cell.mean(yd01, z11) - cell.mean(yd01, z00) + cell.mean(yd01, z10)) / rho[["cc"]]
  psi.00.cc <- (cell.mean(yd00, z00) - cell.mean(yd00, z01) - cell.mean(yd00, z10) + cell.mean(yd00, z11)) / rho[["cc"]]
  psi.10.ac <- (cell.mean(yd10, z00) - cell.mean(yd10, z01)) / rho[["ac"]]
  psi.10.cn <- (cell.mean(yd10, z11) - cell.mean(yd10, z01)) / rho[["cn"]]
  psi.01.nc <- (cell.mean(yd01, z11) - cell.mean(yd01, z10)) / rho[["nc"]]
  psi.01.ca <- (cell.mean(yd01, z00) - cell.mean(yd01, z10)) / rho[["ca"]]
  psi.00.cn <- (cell.mean(yd00, z01) - cell.mean(yd00, z11)) / rho[["cn"]]
  psi.00.nc <- (cell.mean(yd00, z10) - cell.mean(yd00, z11)) / rho[["nc"]]

  # rho.nn (element 9) is omitted: gfun() recovers it from the other eight.
  shares <- rho[-9]
  names(shares) <- paste0("rho.", names(shares))

  c(shares,
    psi.11.cc = psi.11.cc, psi.11.ca = psi.11.ca,
    psi.11.ac = psi.11.ac, psi.11.aa = psi.11.aa,
    psi.10.cc = psi.10.cc, psi.10.an = psi.10.an,
    psi.10.ac = psi.10.ac, psi.10.cn = psi.10.cn,
    psi.01.cc = psi.01.cc, psi.01.nc = psi.01.nc,
    psi.01.na = psi.01.na, psi.01.ca = psi.01.ca,
    psi.00.cc = psi.00.cc, psi.00.cn = psi.00.cn,
    psi.00.nc = psi.00.nc, psi.00.nn = psi.00.nn)
}


# Monte Carlo loop ----
# Each replication draws a data set of N units, then records the first GMM
# coefficient, the difference-in-differences numerator, the joint.iv()
# estimate with its standard error, and the GMM contrast of coefficients 9
# and 21 (psi.11.cc and psi.00.cc in gfun's ordering) with its standard error.
set.seed(02143)

# z1 and z2 are shuffles of rep(c(0, 1), N/2), which has length N only when
# N is even; an odd N would silently recycle against the other vectors.
stopifnot(N %% 2 == 0)

# Type labels, in the column order of the multinomial draw.
type.names <- c("cc", "ca", "cn", "ac", "aa", "an", "nc", "na", "nn")
# Mean of every random coefficient for each type, in the same order.
type.means <- c(1, 0.2, 0.1, -0.1, 0.2, 0.2, -0.4, 0.2, 0.7)

for (i in seq_len(sims)) {
  if ((i %% 10) == 0) cat("sim ", i, " of ", sims, "\n")

  # One-hot N x 9 matrix of type membership.
  c1 <- t(rmultinom(N, 1, cprobs))
  colnames(c1) <- type.names

  # Two instruments, each assigned to exactly half of the sample.
  z1 <- sample(rep(c(0, 1), N / 2))
  z2 <- sample(rep(c(0, 1), N / 2))

  # Types whose first letter is "c" take treatment 1 iff z1 = 1, those whose
  # first letter is "a" always take it; the second letter does the same for
  # treatment 2 and z2.
  d1 <- rowSums(c1[, c("cc", "ca", "cn")]) * z1 + rowSums(c1[, c("aa", "ac", "an")])
  d2 <- rowSums(c1[, c("cc", "ac", "nc")]) * z2 + rowSums(c1[, c("aa", "ca", "na")])

  # Unit-level mean implied by type membership; shared by all four random
  # coefficients. The rnorm() calls stay in their original order so the
  # random-number stream is unchanged.
  mu <- rowSums(t(type.means * t(c1)))
  beta.0 <- rnorm(N, mu, 1)
  beta.d1 <- rnorm(N, mu, 1)
  beta.d2 <- rnorm(N, mu, 1)
  beta.int <- rnorm(N, mu, 1)
  epsilon <- rnorm(N, 0, .5)

  # Potential outcomes under each treatment combination.
  y00 <- beta.0 + epsilon
  y10 <- beta.0 + beta.d1 + epsilon
  y01 <- beta.0 + beta.d2 + epsilon
  y11 <- beta.0 + beta.d1 + beta.d2 + beta.int + epsilon

  # Observed outcome.
  y <- y00 + d1 * (y10 - y00) + d2 * (y01 - y00) + d1 * d2 * ((y11 - y01) - (y10 - y00))

  # Treatment-combination indicators (not used further inside this loop).
  f00 <- (1 - d1) * (1 - d2)
  f10 <- d1 * (1 - d2)
  f01 <- (1 - d1) * d2
  f11 <- d1 * d2
  s11 <- d1 * d2 * y
  s00 <- (1 - d1) * (1 - d2) * y
  z11 <- z1 == 1 & z2 == 1
  z10 <- z1 == 1 & z2 == 0
  z01 <- z1 == 0 & z2 == 1
  z00 <- z1 == 0 & z2 == 0

  # GMM fit, started from the closed-form estimates.
  t0 <- local.est(y, d1, d2, z1, z2)
  gout <- gmm(gfun, x = cbind(d1, d2, z1, z2, y), t0 = t0, wmatrix = "optimal")
  ghold[i] <- gout$coefficients[1]

  num.hold[i] <- mean(s11[z11]) - mean(s00[z11]) - (mean(s11[z10]) - mean(s00[z10])) - (mean(s11[z01]) - mean(s00[z01])) + (mean(s11[z00]) - mean(s00[z00]))

  out <- joint.iv(y, d1, d2, z1, z2)
  laje.hold[i] <- out$tau["laje"]
  laje.se[i] <- out$se["laje"]

  # Var(a - b) = Var(a) + Var(b) - 2 Cov(a, b); fetch the covariance matrix
  # once instead of once per entry.
  gmm.vcov <- vcov(gout)
  gmm.hold[i] <- gout$coefficients[9] - gout$coefficients[21]
  gmm.se[i] <- sqrt(gmm.vcov[9, 9] + gmm.vcov[21, 21] - 2 * gmm.vcov[9, 21])
}


# Monte Carlo summaries ----
# Bare expressions: they are auto-printed when the script is run at top level.

# First GMM coefficient and the difference-in-differences numerator.
mean(ghold)
sd(ghold)
mean(num.hold)
sd(num.hold)

# joint.iv() estimate: mean, simulation SD, and average reported SE.
mean(laje.hold)
sd(laje.hold)
mean(laje.se)

# GMM contrast: mean, simulation SD, and average reported SE.
mean(gmm.hold)
sd(gmm.hold)
mean(gmm.se)


# Sampling densities of both estimators; the grey line marks the value 3
# that the RMSE lines below also use as the target.
cairo_pdf(
  file = "../figs/gmm-sims.pdf",
  width = 5,
  height = 4,
  pointsize = 9,
  family = "Minion Pro",
  bg = "transparent"
)
plot(
  density(gmm.hold),
  lwd = 2, col = "black", bty = "n", las = 1,
  main = "", xlab = "LAJE Estimate"
)
lines(density(laje.hold), lwd = 2, col = "indianred")
abline(v = 3, col = "grey")
dev.off()

# Root mean squared error of each estimator around 3.
sqrt(mean((laje.hold - 3)^2))
sqrt(mean((gmm.hold - 3)^2))
